// Copyright 2020 The XLS Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <cstdint>
#include <filesystem>
#include <memory>
#include <optional>
#include <string>
#include <string_view>
#include <vector>

#include "absl/flags/flag.h"
#include "absl/log/check.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "xls/common/exit_status.h"
#include "xls/common/file/filesystem.h"
#include "xls/common/init_xls.h"
#include "xls/common/status/status_macros.h"
#include "xls/eco/ir_patch.pb.h"
#include "xls/eco/patch_ir.h"
#include "xls/ir/function_base.h"
#include "xls/ir/ir_parser.h"
#include "xls/ir/node.h"
#include "xls/ir/package.h"
#include "xls/scheduling/pipeline_schedule.h"
#include "xls/scheduling/pipeline_schedule.pb.h"
#include "xls/scheduling/scheduling_options.h"

// Usage text handed to xls::InitXls in main(). Keep the flag names in the
// example invocation in sync with the ABSL_FLAG definitions in this file.
static constexpr std::string_view kUsage = R"(
Patches a given IR with a given patch generated by ir_patch_gen.py. The
optional schedule file is used to constrain the schedule of the patched IR.

Example invocation:
  patch_ir_main --input_ir_path=<ir file> --input_patch_path=<patch file> \
      --output_ir_path=<patched ir file> \
      [--input_schedule_path=<schedule file>]
)";

// Command-line flags. main() reads each of these once and forwards the values
// to RealMain().
//
// Path of the IR package to patch; the package's top entity is the function
// the patch is applied to. Defaults to empty, so it is expected to be set.
ABSL_FLAG(std::string, input_ir_path, "",
          "Path to the IR file to patch.");  // NOLINT
// Path of the binary-serialized xls_eco::IrPatchProto to apply. Defaults to
// empty, so it is expected to be set.
ABSL_FLAG(std::string, input_patch_path, "",
          "Path to the patch file.");  // NOLINT
// Destination of the patched IR; a missing parent directory is created before
// writing. Defaults to empty, so it is expected to be set.
ABSL_FLAG(std::string, output_ir_path, "",
          "Path to the patched IR file.");  // NOLINT
// Optional text-format PackagePipelineSchedulesProto. When set, a schedule is
// built for the patched function from it and patched/exported via PatchIr.
ABSL_FLAG(std::optional<std::string>, input_schedule_path, std::nullopt,
          "Path to the schedule file.");  // NOLINT

namespace xls {

// Builds a PipelineSchedule for `function` from the entry keyed by
// `function->name()` in `package_schedules_proto`.
//
// This is a relaxed variant of PipelineSchedule::FromProto: timed nodes in the
// proto that cannot be found in `function` (presumably nodes removed by the
// patch) are skipped instead of producing an error. Nodes of `function` that
// the proto does not mention are simply absent from the cycle map.
//
// Returns InvalidArgumentError if the proto has no schedule for `function`.
static absl::StatusOr<PipelineSchedule> PipelineScheduleFromProto(
    FunctionBase* function,
    const PackagePipelineSchedulesProto& package_schedules_proto) {
  const auto schedule_it =
      package_schedules_proto.schedules().find(function->name());
  if (schedule_it == package_schedules_proto.schedules().end()) {
    return absl::InvalidArgumentError(absl::StrCat(
        "Function does not have a schedule: ", function->name()));
  }
  const auto& schedule_proto = schedule_it->second;

  ScheduleCycleMap cycle_map;
  for (const auto& stage : schedule_proto.stages()) {
    for (const auto& timed_node : stage.timed_nodes()) {
      // Single lookup; a failed lookup means the node is not in `function`
      // and is deliberately skipped.
      absl::StatusOr<Node*> node = function->GetNode(timed_node.node());
      if (node.ok()) {
        cycle_map[*node] = stage.stage();
      }
    }
  }

  std::optional<int64_t> min_clock_period_ps;
  if (schedule_proto.has_min_clock_period_ps()) {
    min_clock_period_ps = schedule_proto.min_clock_period_ps();
  }
  return PipelineSchedule(function, cycle_map, /*length=*/std::nullopt,
                          min_clock_period_ps);
}

// Applies the binary xls_eco::IrPatchProto at `patch_path` to the top entity
// of the IR package at `ir_path` and writes the patched IR to `output_path`,
// creating the output's parent directory if it does not exist.
//
// If `input_schedule_path` is set, the text-format
// PackagePipelineSchedulesProto at that path is turned into a schedule for the
// patched function, which is then patched and exported through PatchIr.
//
// Returns InvalidArgumentError for an empty `output_path`, an unparsable patch
// file, or an IR package without a top entity; propagates all other errors.
static absl::Status RealMain(
    const std::string& ir_path, const std::string& patch_path,
    const std::string& output_path,
    const std::optional<std::string>& input_schedule_path) {
  if (output_path.empty()) {
    return absl::InvalidArgumentError("Output IR path must not be empty.");
  }

  XLS_ASSIGN_OR_RETURN(std::string ir_data, GetFileContents(ir_path));
  XLS_ASSIGN_OR_RETURN(std::string patch_data, GetFileContents(patch_path));
  xls_eco::IrPatchProto patch;
  // A malformed patch file must be reported, not silently ignored.
  if (!patch.ParseFromString(patch_data)) {
    return absl::InvalidArgumentError(
        absl::StrCat("Failed to parse IrPatchProto from ", patch_path));
  }

  XLS_ASSIGN_OR_RETURN(std::unique_ptr<Package> package,
                       Parser::ParsePackage(ir_data));
  const std::optional<FunctionBase*> top = package->GetTop();
  if (!top.has_value()) {
    return absl::InvalidArgumentError(
        absl::StrCat("IR package from ", ir_path, " has no top entity."));
  }
  FunctionBase* function_base = *top;

  // Kept at function scope (as before) so the schedule outlives every later
  // PatchIr call. NOTE(review): PatchIr may retain a reference to it; confirm
  // before narrowing this scope.
  std::optional<PipelineSchedule> schedule;
  PatchIr patch_ir(function_base, patch);
  XLS_RETURN_IF_ERROR(patch_ir.ApplyPatch());
  // Use the parameter rather than the global flag so RealMain honors exactly
  // what its caller passes.
  if (input_schedule_path.has_value()) {
    XLS_ASSIGN_OR_RETURN(PackagePipelineSchedulesProto schedule_proto,
                         ParseTextProtoFile<PackagePipelineSchedulesProto>(
                             *input_schedule_path));
    XLS_ASSIGN_OR_RETURN(
        schedule, PipelineScheduleFromProto(function_base, schedule_proto));
    XLS_RETURN_IF_ERROR(patch_ir.PatchSchedule(schedule.value()));
    XLS_RETURN_IF_ERROR(patch_ir.ExportScheduleProto());
  }

  // A bare file name has an empty parent path (the current directory), in
  // which case there is nothing to create.
  const std::filesystem::path output_dir =
      std::filesystem::path(output_path).parent_path();
  if (!output_dir.empty() && !FileExists(output_dir).ok()) {
    XLS_RETURN_IF_ERROR(RecursivelyCreateDir(output_dir));
  }
  XLS_RETURN_IF_ERROR(patch_ir.ExportIr(output_path));
  return absl::OkStatus();
}

}  // namespace xls

int main(int argc, char** argv) {
  std::vector<std::string_view> positional_args =
      xls::InitXls(kUsage, argc, argv);
  return xls::ExitStatus(xls::RealMain(
      absl::GetFlag(FLAGS_input_ir_path), absl::GetFlag(FLAGS_input_patch_path),
      absl::GetFlag(FLAGS_output_ir_path),
      absl::GetFlag(FLAGS_input_schedule_path)));
}
